{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "import pandas\n",
    "import os\n",
    "import hsml\n",
    "from hsml.client.exceptions import RestAPIError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_env():    \n",
    "\n",
    "    connection = hsml.connection()\n",
    "    mr = connection.get_model_registry()\n",
    "        \n",
    "    return mr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mr = setup_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test metrics, description, model_schema and input examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exported_tf_model = mr.get_model(\"model_tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert isinstance(exported_tf_model, hsml.tensorflow.model.Model)\n",
    "\n",
    "assert 'accuracy' in exported_tf_model.training_metrics and 'loss' in exported_tf_model.training_metrics\n",
    "\n",
    "assert exported_tf_model.description == \"A test desc for this model\"\n",
    "print(exported_tf_model.model_schema)\n",
    "# Check input_example and model schema\n",
    "assert len(exported_tf_model.model_schema['input_schema']['columnar_schema']) == 3, \"schema len incorrect\"\n",
    "assert exported_tf_model.model_schema['output_schema']['tensor_schema']['type'] == \"float64\", \"schema type incorrect\"\n",
    "assert exported_tf_model.model_schema['output_schema']['tensor_schema']['shape'] == '(8,)', \"schema shape incorrect\"\n",
    "assert len(exported_tf_model.input_example) == 3, \"input example len incorrect\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exported_tf_model_v3 = mr.get_model(\"model_tf\", version=3)\n",
    "assert exported_tf_model_v3.version == 3, \"Model version should be 3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    skl_model = mr.get_model(\"not_found\")\n",
    "    assert False, \"should return RestAPIError\"\n",
    "except RestAPIError:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skl_model = mr.get_model(\"model_sklearn\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skl_model.delete()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_models = mr.get_models(\"model_tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(tf_models) == 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_tf_model = mr.get_best_model(\"model_tf\", \"accuracy\", \"max\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert best_tf_model.version == 2, \"Highest accuracy should be version 2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = best_tf_model.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert 'saved_model.pb' in os.listdir(model_dir), \"Model file should be in the downloaded model directory\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PySpark",
   "language": "python",
   "name": "pysparkkernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "python",
    "version": 3
   },
   "mimetype": "text/x-python",
   "name": "pyspark",
   "pygments_lexer": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}